{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by the cells below.\n",
    "\n",
    "# Kubernetes API proxy address and namespace handed to the Clipper container manager;\n",
    "# both are also used to build the query URL in the last cell.\n",
    "K8S_PROXY_ADDR = '127.0.0.1:8001'\n",
    "K8S_NAMESPACE = 'mdt'\n",
    "\n",
    "# Clipper application name, and the model name/version passed to the deployers.\n",
    "APP_NAME = 'tf-classification'\n",
    "MODEL_NAME = 'clipper-tf-predict'\n",
    "VERSION = 1\n",
    "\n",
    "# File name of the model checkpoint.\n",
    "MODEL_FILE = 'model.ckpt'\n",
    "\n",
    "# Container registry passed to the deployers as `registry`.\n",
    "REPO_URL = '658391232643.dkr.ecr.us-west-2.amazonaws.com'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the model and save it to a TensorFlow checkpoint file (`model.ckpt`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.contrib.tensor_forest.python import tensor_forest\n",
    "from tensorflow.python.ops import resources\n",
    "\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Ignore all GPUs, tf random forest does not benefit from it.\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "\n",
    "# Load the iris dataset and hold out 33% of the samples for testing.\n",
    "data = load_iris()\n",
    "dX = data[\"data\"]\n",
    "dy = data[\"target\"]\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    dX, dy, test_size=0.33, random_state=42\n",
    ")\n",
    "\n",
    "# Training parameters\n",
    "num_steps = 500  # total number of training steps\n",
    "batch_size = 10  # samples drawn for each step\n",
    "\n",
    "# Dataset and forest parameters (passed to ForestHParams below)\n",
    "num_classes = 3  # iris has three target classes\n",
    "num_features = 4  # four measurements per sample\n",
    "num_trees = 10\n",
    "max_nodes = 100\n",
    "\n",
    "# Build the TensorFlow graph: input placeholders, the random forest,\n",
    "# an accuracy metric and an op that initializes variables and forest resources.\n",
    "#\n",
    "# NOTE(review): no op below is given an explicit name, and the restore cell further\n",
    "# down looks tensors up by name ('Placeholder:0', 'Placeholder_1:0', 'Mean_1:0',\n",
    "# 'probabilities:0'). Presumably those are the auto-generated names of the ops\n",
    "# created here, so changing the creation order may break those lookups -- confirm.\n",
    "\n",
    "# Input and Target data\n",
    "# Feature batch: any number of rows, num_features columns\n",
    "X = tf.placeholder(tf.float32, shape=[None, num_features])\n",
    "# For random forest, labels must be integers (the class id)\n",
    "Y = tf.placeholder(tf.int32, shape=[None])\n",
    "\n",
    "# Random Forest Parameters\n",
    "hparams = tensor_forest.ForestHParams(num_classes=num_classes,\n",
    "                                      num_features=num_features,\n",
    "                                      num_trees=num_trees,\n",
    "                                      max_nodes=max_nodes).fill()\n",
    "\n",
    "# Build the Random Forest\n",
    "forest_graph = tensor_forest.RandomForestGraphs(hparams)\n",
    "# Get training graph and loss\n",
    "train_op = forest_graph.training_graph(X, Y)\n",
    "loss_op = forest_graph.training_loss(X, Y)\n",
    "\n",
    "# Measure the accuracy\n",
    "# Only the first of the three values returned by inference_graph is used.\n",
    "infer_op, _, _ = forest_graph.inference_graph(X)\n",
    "# The argmax of infer_op along axis 1 is taken as the predicted class id and\n",
    "# compared with the labels (cast to int64 to match argmax's output type).\n",
    "correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))\n",
    "# Accuracy = mean of the 0/1 correctness flags over the fed batch\n",
    "accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "# Initialize the variables (i.e. assign their default value) and forest resources\n",
    "init_vars = tf.group(tf.global_variables_initializer(),\n",
    "                     resources.initialize_resources(resources.shared_resources()))\n",
    "\n",
    "\n",
    "def next_batch(size):\n",
    "    \"\"\"Return `size` randomly chosen training samples and their labels.\n",
    "\n",
    "    Row positions are drawn from X_train with np.random.choice (which samples\n",
    "    with replacement by default), so a sample may appear more than once.\n",
    "    \"\"\"\n",
    "    picked = np.random.choice(len(X_train), size)\n",
    "    return X_train[picked], y_train[picked]\n",
    "\n",
    "\n",
    "# Start TensorFlow session\n",
    "sess = tf.Session()\n",
    "\n",
    "# Run the initializer\n",
    "sess.run(init_vars)\n",
    "\n",
    "saver = tf.train.Saver()\n",
    "\n",
    "# Training\n",
    "for step in range(1, num_steps + 1):\n",
    "    # Draw a random mini-batch of iris samples together with their labels\n",
    "    batch_x, batch_y = next_batch(batch_size)\n",
    "    _, loss = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})\n",
    "    # Report progress on the first step and then every 50 steps\n",
    "    if step % 50 == 0 or step == 1:\n",
    "        acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})\n",
    "        print('Step %i, Loss: %f, Acc: %f' % (step, loss, acc))\n",
    "\n",
    "# Test Model\n",
    "print(\"Test Accuracy:\", sess.run(\n",
    "    accuracy_op, feed_dict={X: X_test, Y: y_test}))\n",
    "\n",
    "# Print the tensors related to this model\n",
    "print(accuracy_op)\n",
    "print(infer_op)\n",
    "print(X)\n",
    "print(Y)\n",
    "\n",
    "# Save the model to a checkpoint file in the current directory. The name comes\n",
    "# from MODEL_FILE (first cell), so the default resolves to ./model.ckpt.\n",
    "save_path = saver.save(sess, os.path.join(os.curdir, MODEL_FILE))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Connect to clipper and register App"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from clipper_admin import ClipperConnection, KubernetesContainerManager\n",
    "# Reach the Clipper cluster through the Kubernetes API proxy.\n",
    "manager = KubernetesContainerManager(\n",
    "    kubernetes_proxy_addr=K8S_PROXY_ADDR,\n",
    "    namespace=K8S_NAMESPACE)\n",
    "clipper_conn = ClipperConnection(manager)\n",
    "clipper_conn.connect()\n",
    "\n",
    "# Register the application that models get linked to.\n",
    "# NOTE(review): slo_micros is presumably in microseconds, i.e. 100000000 = 100 s -- confirm.\n",
    "clipper_conn.register_application(\n",
    "    name=APP_NAME,\n",
    "    input_type='doubles',\n",
    "    default_output='0',\n",
    "    slo_micros=100000000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a prediction function by loading the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open a new session on the default graph and restore the trained model into it.\n",
    "# The checkpoint name comes from MODEL_FILE (first cell) instead of being repeated here.\n",
    "sess = tf.Session()\n",
    "graph = tf.get_default_graph()\n",
    "saver = tf.train.import_meta_graph(MODEL_FILE + '.meta')\n",
    "saver.restore(sess, MODEL_FILE)\n",
    "\n",
    "# Look the tensors up by name.\n",
    "# NOTE(review): apart from 'probabilities:0' these look like TensorFlow's\n",
    "# auto-generated names for the placeholders and accuracy op of the training\n",
    "# cell -- confirm against the tensors printed there.\n",
    "load_infer_op = graph.get_tensor_by_name('probabilities:0')\n",
    "accuracy_op = graph.get_tensor_by_name('Mean_1:0')\n",
    "oX = graph.get_tensor_by_name('Placeholder:0')\n",
    "oY = graph.get_tensor_by_name('Placeholder_1:0')\n",
    "\n",
    "def predict(X):\n",
    "    \"\"\"Run the restored inference op on a batch of feature rows.\n",
    "\n",
    "    X is fed to the restored input placeholder `oX`. Each row of the result is\n",
    "    converted to its string form and the list of strings is returned. The input\n",
    "    and the returned list are printed for debugging.\n",
    "    \"\"\"\n",
    "    print(\"inputs {}\".format(X))\n",
    "    scores = sess.run(load_infer_op, feed_dict={oX: X})\n",
    "    ret = list(map(str, scores))\n",
    "    print(\"return is {}\".format(ret))\n",
    "    return ret\n",
    "\n",
    "\n",
    "# Smoke test with a single sample\n",
    "print(predict([[5.9, 3.0, 5.1, 1.8]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deploying with Clipper's Python closure deployer failed because the prediction function could not be pickled; this is a known issue when pickling Cython classes. See https://stackoverflow.com/questions/12646436/pickle-cython-class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from clipper_admin.deployers import python as python_deployer\n",
    "# Deploy `predict` as a Python closure under MODEL_NAME / VERSION.\n",
    "python_deployer.deploy_python_closure(\n",
    "    clipper_conn,\n",
    "    name=MODEL_NAME,\n",
    "    version=VERSION,\n",
    "    input_type=\"doubles\",\n",
    "    func=predict,\n",
    "    registry=REPO_URL,\n",
    "    pkgs_to_install=['tensorflow'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deploying with Clipper's TensorFlow deployer failed as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from clipper_admin.deployers.tensorflow import deploy_tensorflow_model\n",
    "\n",
    "# Deploy `predict` together with the restored session via the TensorFlow deployer.\n",
    "deploy_tensorflow_model(\n",
    "    clipper_conn,\n",
    "    name=MODEL_NAME,\n",
    "    version=VERSION,\n",
    "    input_type=\"doubles\",\n",
    "    func=predict,\n",
    "    tf_sess_or_saved_model_path=sess,\n",
    "    registry=REPO_URL,\n",
    "    pkgs_to_install=['tensorflow'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Link the model to application"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Link the deployed model to the registered application.\n",
    "clipper_conn.link_model_to_app(\n",
    "    app_name=APP_NAME,\n",
    "    model_name=MODEL_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test the application by sending it a prediction request."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests, json, numpy as np\n",
    "def query_app():\n",
    "    \"\"\"POST one sample input to the Clipper application and print the JSON reply.\n",
    "\n",
    "    Deliberately not called `predict`: an earlier cell defines the model's\n",
    "    `predict(X)` function, which the deploy cells pass as `func`. Reusing the\n",
    "    name here would silently replace it if those cells were run again.\n",
    "    \"\"\"\n",
    "    headers = {\"Content-type\": \"application/json\"}\n",
    "    data = json.dumps({\"input\": [5.9, 3.0, 5.1, 1.8]})\n",
    "    print(data)\n",
    "    # The request goes through the Kubernetes API proxy to the query frontend service.\n",
    "    url = \"http://{}/api/v1/namespaces/{}/services/query-frontend-at-default-cluster:1337/proxy/{}/predict\".format(\n",
    "        K8S_PROXY_ADDR, K8S_NAMESPACE, APP_NAME)\n",
    "    res = requests.post(url, headers=headers, data=data)\n",
    "    print(res.json())\n",
    "\n",
    "\n",
    "query_app()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
